package mymath;

import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;

import net.sf.javaml.core.Dataset;
import net.sf.javaml.core.Instance;
import net.sf.javaml.tools.data.FileHandler;

/**
 * Principal-component-analysis helper that projects Java-ML instances onto two basis
 * vectors and can synthesise instances back from points on that 2-D plane.
 *
 * <p>The basis vectors are taken from the first two rows returned by
 * {@code MyMatrix.qrEgisAsRows().sliceAllRows()} applied to the covariance matrix of the
 * training data. NOTE(review): presumably these are the leading eigenvectors; the ordering
 * is decided by {@code MyMatrix} and should be confirmed there.
 *
 * <p>Algorithm walkthrough: https://blog.csdn.net/ymengm/article/details/123862936
 */
public class MyPCA {
    /** Per-attribute mean of the training data, indexed by compacted row index. */
    private final double[] mean;
    /** First basis vector (row 0 of the decomposition result). */
    private final MyMatrix vecX;
    /** Second basis vector (row 1 of the decomposition result). */
    private final MyMatrix vecY;
    /** Maps an original attribute key to its compacted row index in the data matrix. */
    private final Map<Integer, Integer> legalProp;
    /** Number of attribute columns written per row by {@link #fromPlain}. */
    private final int instKeyCount = 8;

    /**
     * Builds the PCA tool from a dataset, using rows of the covariance-matrix
     * decomposition as basis vectors.
     *
     * <p>The set of usable attributes is decided from the first instance only: an
     * attribute is kept when its value there is neither NaN nor exactly {@code 0.0}.
     * Values that are NaN or absent in later instances are left at the matrix default.
     * NOTE(review): assumes {@code MyMatrix} cells start at 0 — confirm.
     *
     * @param dataset training data; must contain at least one instance
     * @throws IllegalArgumentException if {@code dataset} is empty
     */
    public MyPCA(Dataset dataset) {
        int col = dataset.size();
        if (col == 0) {
            throw new IllegalArgumentException("dataset must contain at least one instance");
        }

        // Select attributes using the first instance as the template.
        legalProp = new HashMap<>();
        int row = 0;
        for (Entry<Integer, Double> pair : dataset.instance(0).entrySet()) {
            double value = pair.getValue();
            if (!Double.isNaN(value) && value != 0.0) {
                legalProp.put(pair.getKey(), row++);
            }
        }

        // Data matrix layout: one row per attribute, one column per instance.
        MyMatrix mat = new MyMatrix(row, col);
        for (int i = 0; i < col; i++) {
            Set<Entry<Integer, Double>> insti = dataset.instance(i).entrySet();
            for (Entry<Integer, Double> pair : insti) {
                Integer target = legalProp.get(pair.getKey());
                if (!Double.isNaN(pair.getValue()) && target != null) {
                    mat.mat[target][i] = pair.getValue();
                }
            }
        }

        // Centre every attribute row on its mean.
        mean = new double[row];
        for (int i = 0; i < row; i++) {
            double sum = 0;
            for (int j = 0; j < col; j++) {
                sum += mat.mat[i][j];
            }
            mean[i] = sum / col;
            for (int j = 0; j < col; j++) {
                mat.mat[i][j] -= mean[i];
            }
        }

        // Covariance matrix: (1 / n) * X * X^T.
        MyMatrix mR = mat.multiple(mat.transpose()).multipleNumber(1.0 / col);
        MyMatrix[] vecs = mR.qrEgisAsRows().sliceAllRows();
        vecX = vecs[0];
        vecY = vecs[1];
        for (int i = 0; i < vecs.length; i++) {
            System.out.println(vecs[i]);
        }
    }

    /**
     * Converts an instance into a mean-centred 1 x k row matrix, where k is the number
     * of selected attributes. Attributes that are NaN or not selected stay at the
     * matrix default.
     */
    private MyMatrix toMatrix(Instance inst) {
        int col = legalProp.size();
        MyMatrix mat = new MyMatrix(1, col);
        for (Entry<Integer, Double> pair : inst.entrySet()) {
            Integer target = legalProp.get(pair.getKey());
            if (!Double.isNaN(pair.getValue()) && target != null) {
                mat.mat[0][target] = pair.getValue() - mean[target];
            }
        }
        return mat;
    }

    /**
     * Projects an instance onto the 2-D plane spanned by the two basis vectors.
     *
     * @param inst instance to project
     * @return a two-element array: the dot products of the centred instance with
     *         the first and second basis vectors, in that order
     */
    public double[] resolveProjection(Instance inst) {
        MyMatrix mat = toMatrix(inst);
        double[] res = new double[2];
        res[0] = mat.dotMulAsRow(vecX);
        res[1] = mat.dotMulAsRow(vecY);
        return res;
    }

    /** Euclidean distance between two 2-D points. */
    private double dis(double[] a, double[] b) {
        double dx = a[0] - b[0];
        double dy = a[1] - b[1];
        return Math.sqrt(dx * dx + dy * dy);
    }

    /**
     * Builds a simulated high-dimensional dataset whose projections cover the given
     * plane points while staying anchored to real test data.
     *
     * <p>For each plane point the nearest projected test point is found; the matching
     * test instance is then shifted along the two basis vectors by the 2-D offset and
     * un-centred. Rows are written to {@code ./plain.csv} (first column is the literal
     * label {@code None}) and the file is loaded back as a dataset.
     *
     * @param plain   generated 2-D points to lift back to the original space
     * @param points  2-D projections of the test instances; must be non-empty
     * @param testSet dataset holding the original test instances
     * @param indexes maps a position in {@code points} to an instance index in {@code testSet}
     * @return the dataset loaded from the written file, or {@code null} if an I/O error occurred
     */
    public Dataset fromPlain(double[][] plain, double[][] points, Dataset testSet, int[] indexes) {
        try {
            // try-with-resources guarantees the writer is closed (and flushed) even when
            // writing fails, and before the file is read back below.
            try (FileWriter fileWriter = new FileWriter("./plain.csv")) {
                for (int i = 0; i < plain.length; i++) {
                    int index = nearestIndex(plain[i], points);
                    MyMatrix vec = toMatrix(testSet.instance(indexes[index]))
                            .plus(vecX.multipleNumber(plain[i][0] - points[index][0]))
                            .plus(vecY.multipleNumber(plain[i][1] - points[index][1]));
                    // Undo the centring applied by toMatrix.
                    for (int j = 0; j < vec.mat[0].length; j++) {
                        vec.mat[0][j] += mean[j];
                    }
                    String[] instiStr = new String[instKeyCount + 1];
                    instiStr[0] = "None";
                    for (int j = 1; j < instiStr.length; j++) {
                        // Column j holds attribute key j - 1; unselected attributes stay blank.
                        Integer source = legalProp.get(j - 1);
                        instiStr[j] = source != null ? "" + vec.mat[0][source] : "";
                    }
                    fileWriter.write(String.join(",", instiStr));
                    fileWriter.write("\n");
                }
            }
            Dataset dataset = FileHandler.loadDataset(new File("./plain.csv"), 0, ",");
            System.out.println("Plain dataset created.");
            return dataset;
        } catch (IOException e) {
            e.printStackTrace();
        }
        return null;
    }

    /** Returns the index of the point closest to {@code target}; the first one wins ties. */
    private int nearestIndex(double[] target, double[][] points) {
        int best = 0;
        double bestDistance = dis(target, points[0]);
        for (int j = 1; j < points.length; j++) {
            double candidate = dis(target, points[j]);
            if (candidate < bestDistance) {
                best = j;
                bestDistance = candidate;
            }
        }
        return best;
    }
}
